/*
Function to estimate weighted event study and return aggregated estimates.
*/

cap program drop aggregate_event_study
program define aggregate_event_study
    * Estimate (or restore) a fully interacted, weighted event-study regression
    * and aggregate the cohort-specific coefficients into event-time estimates.
    *
    * Syntax:
    *   aggregate_event_study varlist [if], min_event(#) max_event(#)
    *       store(name) [x(string) save(string) restore percentage]
    *
    * Arguments / options:
    *   varlist     placed first in the reghdfe call; its X_wt-weighted mean
    *               over e(sample) is recorded as dep_mean
    *   min_event   first event time to aggregate (must be at most -3)
    *   max_event   last event time to aggregate (must be at least 5)
    *   x           extra regressors appended to the regression
    *   store       name used by -estimates store- / -estimates restore-
    *   restore     skip the regression and restore the estimates in store()
    *   percentage  also compute each event-time estimate divided by the
    *               aggregated y0_* combination built from the stored model
    *   save        path of a delimited file that receives all result matrices
    *
    * Variables the code reads from the data in memory: X_wt, lnr, lfirm,
    * cohort_id, year, treated, dcohort*, treated_cohort*, exist_*, et_*.
    * Coefficients referenced by name: et_treat_c<cohort>_<m#|p#>,
    * et_c<cohort>_<m#|p#>, dcohort<cohort>, _cons.
    *
    * Matrices left behind: est_full, sample_stats, est_es, est_agg and, with
    * -percentage-, y0_var, est_pct, est_pct_agg.
    *
    * Requires the user-written commands reghdfe and xlincom.
    version 18.0
    syntax varlist [if], ///
        min_event(int) ///
        max_event(int) ///
        [x(string)] ///
        [save(string)] ///
        store(string) ///
        [restore] ///
        [percentage]

    * validate the event window; the aggregates below assume p1..p5 exist
    if (`max_event' < 5) {
        display "max event needs to be at least 5"
        exit
    }

    * the percentage loop iterates over `min_event'/-3, so a larger min_event
    * would walk into the omitted period -2; stop here like the check above
    if (`min_event' > -3) {
        display "min event needs to be at most -3"
        exit
    }

    if ("`restore'" == "") {
        * run the fully interacted regression
        reghdfe `varlist' ///
            dcohort* ///
            treated_cohort* ///
            exist_* ///
            et_* ///
            `x' ///
            [aweight=X_wt] `if', ///
            noabsorb vce(cluster lnr lfirm)
        if ("`store'" != "") {
            estimates store `store'
        }
    }
    else {
        * reuse previously stored results; replay so that r(table) is populated
        estimates restore `store'
        estimates replay `store'
    }

    mat define est_full = r(table)

    * flag the estimation sample so it survives later -post- calls
    tempvar sample_flag
    gen `sample_flag' = e(sample)

    * get sample means
    mat define sample_stats = J(3, 1, .)
    matrix colnames sample_stats = b
    matrix rownames sample_stats = dep_mean n_firms n_owners

    qui: sum `varlist' [aweight=X_wt] if `sample_flag' == 1
    matrix sample_stats[1, 1] = r(mean)

    * get sample sizes: tag one row per (lfirm, sample) and (lnr, sample) group
    * NOTE: -bys- re-sorts the data in memory
    tempvar firm_flag
    tempvar lnr_flag
    bys lfirm `sample_flag' : gen `firm_flag' = (_n == 1)
    bys lnr `sample_flag' : gen `lnr_flag' = (_n == 1)

    qui: count if `sample_flag' == 1 & `firm_flag' == 1
    matrix sample_stats[2, 1] = r(N)

    qui: count if `sample_flag' == 1 & `lnr_flag' == 1
    matrix sample_stats[3, 1] = r(N)

    * identify first/last cohort used
    qui: sum cohort_id if `sample_flag' == 1
    local cohort_lb = r(min)
    local cohort_ub = r(max)

    * identify first/last years data is observed for outcome
    qui: sum year if `sample_flag' == 1
    local year_lb = r(min)
    local year_ub = r(max)

    * aggregate cohort-specific estimates, weighting each cohort by the number
    * of treated observations it has in the sample at year == cohort - 2
    local es_lincom
    foreach et of numlist `min_event'/`max_event' {
        * coefs will be named e.g. m2 = -2, p2 = 2
        local eta = abs(`et')
        if `et' < 0 {
            local ets "m`eta'"
        }
        else {
            local ets "p`eta'"
        }

        * get cohort-specific weights for the event time period
        local et_wt = 0
        local et_sum "`ets' = (0"
        local agg_wt = 0

        if `et' == -2 {
            * event time -2 is the omitted reference period
            local et_sum "`ets' = 0"
        }
        else {
            forvalues year = `cohort_lb'/`cohort_ub' {
                * only cohorts whose calendar year for this event time is observed
                if (`year' + `et' <= `year_ub') & ///
                     (`year' + `et' >= `year_lb') {
                    * weight by size of treated population
                    qui: count if `sample_flag' == 1 & year == `year' - 2 ///
                        & treated == 1 & cohort_id == `year'
                    local et_wt = r(N)
                    local et_sum "`et_sum' + et_treat_c`year'_`ets' * `et_wt'"
                    local agg_wt = `agg_wt' + `et_wt'
                }
            }

            * normalize by the sum of the weights (so that they sum to one)
            * NOTE(review): agg_wt is 0 if no cohort qualifies - confirm this
            * cannot happen for the event windows used by callers
            local et_sum "`et_sum') / `agg_wt'"
        }

        local es_lincom "`es_lincom' (`et_sum')"

    }

    xlincom `es_lincom', post
    mat define est_es = r(table)

    * define the aggregates
    local agg_lincom "p1" // will be average over [1, max event]

    forvalues ets = 2/`max_event' {
        local agg_lincom "`agg_lincom' + p`ets'"
    }

    local agg_lincom_0 "p0 + `agg_lincom'" // average over [0, max event]

    * get aggregates, test for equality
    xlincom ///
        (agg_post = (`agg_lincom') / `max_event') ///
        (agg_post_w0 = (`agg_lincom_0') / (1 + `max_event')) ///
    , post
    mat define est_agg = r(table)

    ****************************************
    * Estimate percentages
    ***********************
    if ("`percentage'" != "") {

        * the -post- calls above replaced e(); bring the regression back
        estimates restore `store'
        estimates replay `store'

        **** get E[\hat{Y}(0) | cohort, D = 1, t]
        local es_lincom_y0

        preserve
        * Since E[\hat{Y}(0) | cohort, D=1, t] = f(\hat{\beta})'E[X| cohort, D=1, t],
        * first collapse down to get E[X | cohort,D,t]
        keep if `sample_flag' == 1
        collapse (mean) exist_* treated_cohort* `x' ///
            [aweight=X_wt], by(cohort_id treated year)

        * then loop over all cohorts and event times (reference period -2 skipped)
        forvalues year = `cohort_lb'/`cohort_ub' {
            foreach et of numlist `min_event'/-3 -1/`max_event' {

                local eta = abs(`et')
                if `et' < 0 {
                    local ets "m`eta'"
                }
                else {
                    local ets "p`eta'"
                }

                unab itervars : exist_c`year'* ///
                                exist_treat_c`year'* ///
                                treated_cohort* ///
                                `x'

                if (`year' + `et' <= `year_ub') & ///
                     (`year' + `et' >= `year_lb') {

                    * perform the summation to get f(\hat{\beta})'E[X| cohort, D=1, t]
                    local y0_sum "(y0_c`year'_`ets' = _cons"
                    local y0_sum "`y0_sum' + dcohort`year' + et_c`year'_`ets'"

                    foreach iv of local itervars {
                        * E[X|cohort,D,t]
                        qui: sum `iv' if cohort_id == `year' & ///
                                         treated == 1 & ///
                                         year == (`year' + `et')
                        * skip empty cells (r(mean) missing) and zero means
                        if (`r(N)' != 0) {
                            if (`r(mean)' != 0) {
                                local y0_sum "`y0_sum' + `r(mean)' * `iv'"
                            }
                        }

                    }

                    local es_lincom_y0 "`es_lincom_y0' `y0_sum')"
                }
            }
        }
        restore

        * post the y0 combinations jointly with the event-study combinations
        xlincom `es_lincom_y0' `es_lincom', post
        mat define y0_var = e(V)

        ******* get weighted average of E[\hat{Y}(0)|D=1,t]
        local y0_lincom
        local pct_nlcom

        foreach et of numlist `min_event'/`max_event' {
            local eta = abs(`et')
            if `et' < 0 {
                local ets "m`eta'"
            }
            else {
                local ets "p`eta'"
            }

            local pct_wt = 0
            local pct_sum "y0_`ets' = (0"
            local agg_pct = 0

            if `et' == -2 {
                * reference period: percentage effect fixed at 0
                local pct_nlcom "`pct_nlcom' (pct_`ets' : 0)"
            }
            else {
                forvalues year = `cohort_lb'/`cohort_ub' {
                    if (`year' + `et' <= `year_ub') & ///
                         (`year' + `et' >= `year_lb') {
                        * same treated-population weights as the event-study step
                        qui: count if `sample_flag' == 1 & year == `year' - 2 ///
                            & treated == 1 & cohort_id == `year'
                        local pct_wt = r(N)
                        local pct_sum "`pct_sum' + y0_c`year'_`ets' * `pct_wt'"
                        local agg_pct = `agg_pct' + `pct_wt'
                    }
                }

                * NOTE(review): agg_pct is 0 if no cohort qualifies - confirm
                local pct_sum "`pct_sum') / `agg_pct'"
                * carry the effect (ets) alongside y0_ets so nlcom can divide them
                local y0_lincom "`y0_lincom' (`pct_sum') (`ets' = `ets')"
                local pct_nlcom "`pct_nlcom' (pct_`ets' : _b[`ets'] / _b[y0_`ets'])"
            }
        }

        xlincom `y0_lincom', post
        nlcom `pct_nlcom', post
        mat define est_pct = r(table)

        * define the aggregates (agg_pct is reused here as an expression string)
        local agg_pct "pct_p1" // will be average over [1, max event]

        forvalues ets = 2/`max_event' {
            local agg_pct "`agg_pct' + pct_p`ets'"
        }

        local agg_pct_0 "pct_p0 + `agg_pct'" // average over [0, max event]

        * get aggregates, test for equality
        xlincom ///
            (pct_agg_post = (`agg_pct') / `max_event') ///
            (pct_agg_post_w0 = (`agg_pct_0') / (1 + `max_event')) ///
        , post
        mat define est_pct_agg = r(table)
    }

    * save the estimates
    if ("`save'" != "") {
        preserve
        tempfile tempsave
        local flag = 0
        if ("`percentage'" != "") {
            local estmats est_full est_es est_agg est_pct est_pct_agg
        }
        else {
            local estmats est_full est_es est_agg
        }
        foreach estmat in `estmats' {
            * turn the matrix into a dataset and transpose it with two reshapes,
            * so each row is one estimate and each column one matrix row
            clear
            svmat `estmat'
            gen id = _n
            reshape long `estmat', i(id) j(event_time)
            reshape wide `estmat', i(event_time) j(id)
            local RN : rownames `estmat'
            ren (`estmat'*) (`RN')
            * label each row with the matrix column name it came from
            local CN : colnames `estmat'
            local i = 0
            gen X = ""
            foreach var in `CN' {
                local i = `i' + 1
                replace X = "`var'" if _n == `i'
            }
            gen aggregation = "`estmat'"
            * flag == 0 marks the first matrix, which starts the stacked file
            if `flag' == 0 {
                save `tempsave', replace
                local flag = 1
            }
            else {
                append using `tempsave'
                save `tempsave', replace
            }
        }

        * append the sample statistics
        clear
        svmat sample_stats
        ren sample_stats1 b
        local RN : rownames sample_stats
        local i = 0
        gen X = ""
        foreach var in `RN' {
            local i = `i' + 1
            replace X = "`var'" if _n == `i'
        }
        gen aggregation = "summary"
        append using `tempsave'

        * quote the path so file names containing spaces work
        export delimited using "`save'", replace
        restore
    }
end
